import torch

# 输入图像（时域）
x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0],
                  [7.0, 8.0, 9.0]])

# 完整频谱（fft2 输出）
x_fft = torch.fft.fft2(x)
print("完整频谱 (fft2 输出):")
print(x_fft)

# 只保留正频率部分（rfft2 输出）
x_rfft = torch.fft.rfft2(x)
print("\n正频率部分 (rfft2 输出):")
print(x_rfft)